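"""
Utilities to pad and crop images, split images into patches on a fixed grid,
stitch patches back into a tile, and extract patches from image files to disk.
"""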


# Built-in
import os

# Libs
import numpy as np
from tqdm import tqdm

# Own modules
from mrs_utils import misc_utils


def pad_image(img, pad, mode='reflect'):
    """
    Pad pixels around the two spatial axes of an image
    :param img: image to pad, a 2D array (h*w) or an array with trailing axes (e.g. h*w*c);
                only the first two axes are padded, trailing axes are left untouched
    :param pad: #pixels to pad as [top, bottom, left, right] (list or tuple of 4 values);
                if it is a scalar, the same number of pixels is padded in all 4 directions
    :param mode: padding mode, forwarded to numpy.pad
    :return: padded image, with the same dtype as the input image
    """
    if isinstance(pad, (list, tuple)):
        pad = list(pad)
    else:
        pad = [pad] * 4
    assert len(pad) == 4
    # Pad only the spatial axes and give every trailing (channel) axis zero width.
    # Doing it in one np.pad call keeps the input dtype; the previous per-channel loop
    # wrote into a float64 np.zeros buffer.
    pad_width = [(pad[0], pad[1]), (pad[2], pad[3])] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode)


def crop_image(img, y, x, h, w):
    """
    Crop the image with given top-left anchor and corresponding height & width
    :param img: image to be cropped, 2D (h*w) or with trailing axes (e.g. h*w*c)
    :param y: row (vertical) coordinate of the top-left anchor
    :param x: column (horizontal) coordinate of the top-left anchor
    :param h: height of the patch (#rows)
    :param w: width of the patch (#columns)
    :return: view of img covering rows [y, y+h) and columns [x, x+w); trailing axes are kept whole
    """
    # rows are bounded by the height and columns by the width; `...` keeps any trailing axes
    return img[y:y + h, x:x + w, ...]


def make_grid(tile_size, patch_size, overlap):
    """
    Extract patches at fixed locations. Output coordinates for Y,X as a list (not two lists)
    :param tile_size: size of the tile (input image) as (h, w)
    :param patch_size: size of the output patch as (h, w)
    :param overlap: #overlapping pixels between adjacent patches
    :return: list of (y, x) top-left corners, with y varying fastest:
             [(y0, x0), (y1, x0), ..., (y0, x1), (y1, x1), ...]
    :raises ValueError: if overlap is not smaller than the patch size along an axis that needs
                        more than one patch
    """
    def _axis_grid(tile_len, patch_len):
        # Corners along one axis. Each axis is handled independently, so an axis that the
        # patch already covers does not stop the other axis from being tiled.
        max_start = tile_len - patch_len
        if max_start > 0:
            stride = patch_len - overlap
            if stride <= 0:
                raise ValueError('overlap ({}) must be smaller than the patch size ({})'.format(
                    overlap, patch_len))
            # linspace requires an integer number of samples
            n_step = int(np.ceil(tile_len / stride))
        else:
            # the patch covers (or exceeds) the whole axis: a single corner at 0
            n_step = 1
        return np.floor(np.linspace(0, max_start, n_step)).astype(np.int32)

    patch_grid_h = _axis_grid(tile_size[0], patch_size[0])
    patch_grid_w = _axis_grid(tile_size[1], patch_size[1])

    # default 'xy' indexing: after flattening, y varies fastest within each x
    y, x = np.meshgrid(patch_grid_h, patch_grid_w)
    return list(zip(y.flatten(), x.flatten()))


def patch_block(block, pad, grid_list, patch_size, return_coord=False):
    """
    Cut a data block into patches at the given grid corners (generator)
    :param block: data block to be patched, should be h*w*c
    :param pad: #pixels padded around the block with pad_image before cropping; no padding if pad <= 0
    :param grid_list: iterable of (y, x) top-left corners, in the coordinates of the padded block
    :param patch_size: (h, w) of each patch
    :param return_coord: if True, yield (patch, y, x) tuples instead of bare patches
    :return: generator of patches, or of (patch, y, x) tuples when return_coord is True
    """
    source = pad_image(block, pad) if pad > 0 else block
    for corner in grid_list:
        top, left = corner
        piece = crop_image(source, top, left, patch_size[0], patch_size[1])
        yield (piece, top, left) if return_coord else piece


def unpatch_block(blocks, tile_dim, patch_size, tile_dim_output=None, patch_size_output=None, overlap=0):
    """
    Stitch patches back into one tile, averaging wherever patches overlap.
    Set tile_dim_output and patch_size_output to proper values if padding exists
    :param blocks: data blocks, should be n*h*w*c, ordered like make_grid(tile_dim, patch_size, overlap)
    :param tile_dim: input tile dimension, if padding exists should be h+2*pad, w+2*pad
    :param patch_size: input patch size
    :param tile_dim_output: output tile dimension, defaults to tile_dim
    :param patch_size_output: output patch dimension, defaults to patch_size;
                              if shrinking exists, should be h-2*pad, w-2*pad
    :param overlap: overlap of adjacent patches
    :return: float array of shape tile_dim_output[0]*tile_dim_output[1]*c: per-pixel sum of the
             patches divided by the number of patches covering that pixel
    """
    out_dim = tile_dim if tile_dim_output is None else tile_dim_output
    out_patch = patch_size if patch_size_output is None else patch_size_output
    (_, _, _, n_channels) = blocks.shape
    stitched = np.zeros((out_dim[0], out_dim[1], n_channels))
    coverage = np.zeros_like(stitched)
    # corners are recomputed in input (padded) coordinates; each one anchors an output patch
    corners = make_grid(tile_dim, patch_size, overlap)
    for idx, (top, left) in enumerate(corners):
        rows = slice(top, top + out_patch[0])
        cols = slice(left, left + out_patch[1])
        stitched[rows, cols, :] += blocks[idx, :, :, :]
        coverage[rows, cols, :] += 1
    return stitched / coverage


def patch_extractor(file_list, file_exts, patch_size, pad, overlap, save_path, force_run=False):
    """
    Extract patches from every group of files, save them under save_path and record their paths.

    Progress is tracked in <save_path>/state.txt. It holds 'Incomplete' while running and
    'Finished' once done, so a finished run is skipped unless force_run is True.
    :param file_list: list of lists of the files, can be generated by using collectionMaker.load_files();
                      each inner list holds one file per entry of file_exts
    :param file_exts: extensions of the new (patch) files, one per file in each inner list
    :param patch_size: (h, w) of the patches
    :param pad: #pixels padded around each image before patching
    :param overlap: #overlapping pixels between adjacent patches
    :param save_path: directory where patches, file_list.txt and state.txt are written
    :param force_run: if True, extract even if a previous run has finished
    :return:
    """
    def extract_(file_list, file_exts, patch_size, pad, overlap, save_path):
        # every group of files must have exactly one output extension per file
        if len(file_list) > 0:
            assert len(file_exts) == len(file_list[0])
        pbar = tqdm(file_list)
        # the context manager closes the record file even if loading/saving a patch raises
        with open(os.path.join(save_path, 'file_list.txt'), 'w') as record_file:
            for files in pbar:
                pbar.set_description('Extracting {}'.format(os.path.basename(files[0])))
                patch_list = []
                for f, ext in zip(files, file_exts):
                    patch_list_ext = []
                    img = misc_utils.load_file(f)
                    # grid corners are computed in the coordinates of the padded image
                    grid_list = make_grid(np.array(img.shape[:2]) + 2 * pad, patch_size, overlap)
                    # extract images
                    for patch, y, x in patch_block(img, pad, grid_list, patch_size, return_coord=True):
                        patch_name = '{}_y{}x{}.{}'.format(os.path.basename(f).split('.')[0], int(y), int(x), ext)
                        patch_name = os.path.join(save_path, patch_name)
                        misc_utils.save_file(patch_name, patch.astype(np.uint8))
                        patch_list_ext.append(patch_name)
                    patch_list.append(patch_list_ext)
                # presumably turns per-extension lists into per-location rows, so each record
                # line lists the patches of every file at one grid location -- verify against misc_utils
                patch_list = misc_utils.rotate_list(patch_list)
                for items in patch_list:
                    record_file.write('{}\n'.format(' '.join(items)))

    def check_finish(state_file):
        """
        check if the state file records a finished run
        :return: True if it has finished
        """
        if not os.path.exists(state_file):
            return False
        with open(state_file, 'r') as f:
            first_line = f.readline()
        # an empty state file (e.g. interrupted before anything was written) counts as unfinished
        return first_line.strip() == 'Finished'

    state_file = os.path.join(save_path, 'state.txt')
    # run if forced, or if there is no record of a finished previous run
    if force_run or not check_finish(state_file):
        print('Start running patch extractor')
        # make sure the output directory exists before writing state/record files
        os.makedirs(save_path, exist_ok=True)
        # write state log as incomplete
        with open(state_file, 'w') as f:
            f.write('Incomplete\n')
        extract_(file_list, file_exts, patch_size, pad, overlap, save_path)
        # write state log as complete
        with open(state_file, 'w') as f:
            f.write('Finished\n')


if __name__ == '__main__':
    # nothing to run when this module is executed directly; its functions are meant to be imported
    pass
